{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Put the parent directory on sys.path before the project imports below\n",
    "# (presumably required to resolve them from this folder - TODO confirm).\n",
    "sys.path.append('..')\n",
    "from pb_diff_envs.environment.static.comp_rm2d_env import ComposedRM2DEnv\n",
    "from pb_diff_envs.environment.static.rand_maze2d_env import RandMaze2DEnv\n",
    "\n",
    "maze_size = np.array([5, 5])\n",
    "\n",
    "# First wall group: six 2-D entries, each paired with a (0.5, 0.5) row in h1.\n",
    "# NOTE(review): w* look like wall positions and h* like wall sizes - confirm\n",
    "# against ComposedRM2DEnv.\n",
    "w1 = np.array([\n",
    "    [0.95824706, 4.33524396],\n",
    "    [0.96091322, 2.7401753],\n",
    "    [1.40722512, 1.26307468],\n",
    "    [2.37623081, 2.54474058],\n",
    "    [2.65520526, 0.71410657],\n",
    "    [3.07530283, 3.9151338],\n",
    "])\n",
    "h1 = np.array([0.5, 0.5])[None,].repeat(len(w1), axis=0)\n",
    "\n",
    "# Second wall group: three 2-D entries, each paired with a (0.7, 0.7) row in h2.\n",
    "w2 = np.array([\n",
    "    [0.97086047, 1.20158635],\n",
    "    [2.47202437, 4.00066585],\n",
    "    [4.04977404, 2.85560483],\n",
    "])\n",
    "# Derive the row count from w2 (it was a hard-coded 3) so h2 stays in sync\n",
    "# with w2 when the wall list is edited, matching how h1 is built.\n",
    "h2 = np.array([0.7, 0.7])[None,].repeat(len(w2), axis=0)\n",
    "\n",
    "# Number of entries in each wall group, handed to the renderer config.\n",
    "nwc = np.array([len(w1), len(w2)])\n",
    "\n",
    "robot_config = dict(maze_size=maze_size, min_to_wall_dist=0.01, collision_eps=0.02)\n",
    "wl = [w1, w2]\n",
    "hl = [h1, h2]\n",
    "env = ComposedRM2DEnv(wl, hl, robot_config, renderer_config=dict(num_walls_c=nwc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the composed environment created in the setup cell above.\n",
    "env.render()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
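    "# Disabled debugging aid: uncomment ONE of the constructors below, plus the\n",
    "# render call, to view a single wall group on its own with RandMaze2DEnv.\n",
    "# env = RandMaze2DEnv(w1, h1, robot_config, renderer_config={})\n",
    "# env = RandMaze2DEnv(w2, h2, robot_config, renderer_config={})\n",
    "# env.render()\n"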
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize many trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.kuka_utils_luo import SolutionInterp\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Interpolator that a later cell applies to every sampled solution.\n",
    "sol_interp = SolutionInterp(density=8)\n",
    "\n",
    "# Load the environment with the GUI flag on (the alternative is GUI=False).\n",
    "env.load(GUI=True)\n",
    "\n",
    "# Start from a pose returned by get_robot_free_pose and place the robot there.\n",
    "prev_pose = env.get_robot_free_pose()\n",
    "env.robot.set_config(prev_pose)\n",
    "\n",
    "n_traj = 5  # alternative value that was tried: 25\n",
    "solutions = []\n",
    "\n",
    "# Chain the episodes: each one starts at the final waypoint of the previous one.\n",
    "for _episode in tqdm(range(n_traj)):\n",
    "    waypoints, _ = env.sample_1_episode(prev_pose)\n",
    "    solution = np.array(waypoints)\n",
    "    print(solution.shape)\n",
    "    solutions.append(solution)\n",
    "    prev_pose = solution[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the last expression of a cell is echoed, so the shape has to be\n",
    "# printed explicitly; the array itself is still shown as the cell result.\n",
    "print(solutions[0].shape)\n",
    "solutions[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# All waypoints stacked into one array; not used below, kept for inspection.\n",
    "solutions_np = np.concatenate(solutions, axis=0)\n",
    "\n",
    "# Interpolate every sampled solution. Iterating over solutions itself (rather\n",
    "# than range(n_traj)) keeps this cell correct even if n_traj was changed\n",
    "# after the sampling cell ran.\n",
    "trajs = []\n",
    "for sol in tqdm(solutions):\n",
    "    traj = sol_interp(sol)\n",
    "    print(traj.shape)\n",
    "    trajs.append(traj)\n",
    "\n",
    "# Render all trajectories into one image; the first argument is the path\n",
    "# passed to render_composite, and the return value is displayed below.\n",
    "composite_img = env.render_composite('./luotest_c.png', trajs)\n",
    "print(composite_img.shape)\n",
    "\n",
    "# Drop any stale figures first. The previous plt.clf() created an empty\n",
    "# figure when none was open, which produced a stray blank output.\n",
    "plt.close('all')\n",
    "fig, ax = plt.subplots(dpi=200)\n",
    "ax.imshow(composite_img)\n",
    "ax.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch2ss",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
